iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 29
0

今天我們要實作GAN,但不像以前的AutoEncoder Model,GAN大部分是使用捲積層,而非像之前使用的全連結層,所以經過網路上大師們的建議,由於在訓練GAN會很不穩定,因此一些Layers和激活函數都要特別的注意。今天的實作我是參考https://colab.research.google.com/drive/1hNMJ1C3ARYud-6UDqKGYx12cGZ9ULDZ-#scrollTo=YgH_d6fNVuEw

Generator

Generator主要是圖片產生器,透過tf.layers.Conv2DTranspose,把特徵還原成照片

class Generator(keras.Model):
  def __init__(self):
    super(Generator,self).__init__()
    #encoder
    self.fc_layer_1 = layers.Dense(3*3*512)
    self.conv_1 = layers.Conv2DTranspose(256,3,3,'valid')

    self.bn_1 = layers.BatchNormalization()       
    self.conv_2 = layers.Conv2DTranspose(128,5,2,'valid')
    self.bn_2 = layers.BatchNormalization()     
    self.conv_3 = layers.Conv2DTranspose(3,4,3,'valid')

  def call(self, inputs, training=None):
    x = self.fc_layer_1(inputs)
    x = tf.reshape(x,[-1,3,3,512])
    x = tf.nn.relu(x)
    x = self.bn_1(self.conv_1(x),training=training)
    x = self.bn_2(self.conv_2(x),training=training)
    x = self.conv_3(x)
    x = tf.tanh(x)
    return x

Discriminator

他是一個圖片分類器,用以判斷 Generator 產生圖片的好壞

class Discriminator(keras.Model):
  def __init__(self):
    super(Discriminator,self).__init__()

    self.conv_1 = layers.Conv2D(64,5,3,'valid')
    self.conv_2 = layers.Conv2D(128,5,3,'valid')
    self.bn_1 = layers.BatchNormalization()
    self.conv_3 = layers.Conv2D(256,5,3,'valid')    
    self.bn_2 = layers.BatchNormalization()
    self.flatten = layers.Flatten()
    self.fc_layer = layers.Dense(1)
![https://ithelp.ithome.com.tw/upload/images/20201009/20130246QRoobECSo6.png](https://ithelp.ithome.com.tw/upload/images/20201009/20130246QRoobECSo6.png)
  
  def call(self, inputs, training=None):
    x = tf.nn.leaky_relu(self.conv_1(inputs))    
    x = tf.nn.leaky_relu(self.bn_1(self.conv_2(x),training=training))    
    x = tf.nn.leaky_relu(self.bn_2(self.conv_3(x),training=training))  
    x = self.flatten(x)
    x = self.fc_layer(x)
    return x

測試

g = Generator()
d = Discriminator()
x = tf.random.normal([1,64,64,3])
z = tf.random.normal([1,100])
prob = g(x)
print(prob)
out = d(x)
print(out.shape)

https://ithelp.ithome.com.tw/upload/images/20201009/20130246U0mdE2Qbu9.png
今天只是簡單的GAN模型建立,明天才會把圖畫出來,終於要完成30天發文了!!!!!。


上一篇
GAN(生成對抗網路)
下一篇
GAN實作(二)
系列文
Tensorflow2.030
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言